import os
import secrets
from typing import List

import numpy as np
import uvicorn
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer, models

# Expected bearer token, read from the `sk-key` environment variable.
# NOTE(review): when the variable is unset, the hard-coded placeholder below
# becomes the accepted key — make sure real deployments override it.
_DEFAULT_SK_KEY = 'sk-aaabbbcccdddeeefffggghhhiiijjjkkk'
sk_key = os.getenv('sk-key', _DEFAULT_SK_KEY)

# Application instance that the routes in this module attach to.
app = FastAPI()

# CORS policy: every origin, method and header is accepted, with credentials.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive — confirm this is what the deployment needs.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)

# Load the pretrained transformer backbone from ./m3e-large.
transformer_model = models.Transformer('./m3e-large', cache_dir='./cache')

# Mean-pool the token embeddings into a single vector per sentence.
_embedding_dim = transformer_model.get_word_embedding_dimension()
pooling_model = models.Pooling(_embedding_dim, pooling_mode='mean')

# Chain the backbone and the pooling layer into one SentenceTransformer.
model = SentenceTransformer(modules=[transformer_model, pooling_model])

# HTTP Bearer scheme; injected into the endpoint via Depends(security).
security = HTTPBearer()


class EmbeddingRequest(BaseModel):
    """Request body for the ``/v1/embeddings`` endpoint."""

    # Texts to embed; the endpoint returns one embedding per entry.
    input: List[str]


class EmbeddingResponse(BaseModel):
    """Response body returned by the ``/v1/embeddings`` endpoint."""

    # One entry per input text; the endpoint fills each entry with the keys
    # "embedding" (list of floats) and "index" (position in the request).
    data: list
    # Length of a single embedding vector.
    dimension: int


@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def get_embeddings(request: EmbeddingRequest, credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Return L2-normalised sentence embeddings for every text in the request.

    Args:
        request: Request body whose ``input`` holds the texts to embed.
        credentials: Bearer token extracted by the ``security`` dependency.

    Returns:
        A dict with ``data`` (one ``{"embedding", "index"}`` entry per input
        text, in request order) and ``dimension`` (length of each embedding).

    Raises:
        HTTPException: 401 when the bearer token does not match ``sk_key``;
            400 when ``input`` is empty.
    """
    # Compare as bytes in constant time: this does not leak how much of the
    # key matched, and bytes (unlike str) accept non-ASCII tokens.
    supplied_token = credentials.credentials.encode("utf-8")
    if not secrets.compare_digest(supplied_token, sk_key.encode("utf-8")):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authorization code",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # An empty list has no first embedding to take the dimension from, so
    # reject it explicitly instead of failing with an IndexError (HTTP 500).
    if not request.input:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="input must contain at least one string",
        )

    # Encode all texts in one batched call; the result has one row per text.
    # NOTE(review): encode() is synchronous and runs on the event loop here.
    vectors = np.asarray(model.encode(request.input))

    # L2-normalise each row; a zero-norm row is left as-is rather than being
    # divided by zero (which would yield NaN values that cannot be serialised).
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    embeddings = (vectors / norms).tolist()

    return {
        "data": [
            {"embedding": embedding, "index": index}
            for index, embedding in enumerate(embeddings)
        ],
        "dimension": len(embeddings[0]),
    }


if __name__ == "__main__":
    # Serve the app from the `localEmbedding` module on all interfaces,
    # port 6009, using two worker processes.
    uvicorn.run(
        "localEmbedding:app",
        host="0.0.0.0",
        port=6009,
        workers=2,
    )
